import time

import gym
import torch
from agents.interfaces import Learner
from tools.utils import  preprocess, size_action_space


class SubAgent:
    """Agent that acts in a real environment towards goals supplied by the caller.

    It ties together an environment (``envs``), an acting/learning algorithm
    (``sac``), a rollout storage (``rollouts``) and a shared ``context`` that
    provides the logger and the embedding estimator.
    """

    def __init__(self,args,envs,sac,context,rollouts, gamma=0.99, log_interval=100):
        """Store the collaborators and initialise counters and logging state.

        Args:
            args: Run configuration; this class reads ``state``, ``warmup``,
                ``gamma`` and ``clone_negative`` from it.
            envs: The environment to act in. Must not be a ``Learner``.
            sac: Algorithm object used to act, evaluate, save and load.
            context: Shared object exposing ``logger``, ``estimator`` and a
                mutable ``features`` attribute.
            rollouts: Storage receiving transitions and producing batches.
            gamma: Accepted for interface compatibility but not used here;
                ``learn`` reads ``args.gamma`` instead.
            log_interval: Numerator of the "ups" value reported by ``print``.
        """
        assert not isinstance(envs,Learner),"envs of a coord agent can not be a Learner, it should be a true environment"
        self.context = context
        self.args = args
        self.eval_mode = False
        self.action_space = envs.action_space
        self.action_size = size_action_space(self.action_space)
        self.rollouts = rollouts
        self.log_interval = log_interval
        self.total_num_steps = 1
        self.num_updates = 1  # 1 instead of 0 to avoid a too-early print
        self.envs = envs

        self.episode_rewards = [0.]
        self.start = time.time()
        self.freq_timer = self.start
        self.algo = sac
        self.logger = self.context.logger
        # Snapshot of rollouts.oegn.embeds taken at the previous print().
        self.old_embeds = None
        # Kept so existing readers of this attribute still find it; it is
        # never assigned anything else by this class.
        self.old_obs = None
        # Snapshots (observations / estimator store) taken at the previous
        # print(), used to measure how far the estimator has drifted since.
        self.old_obs_e = None
        self.old_embeds_e = None

    def _extract_observation(self, raw_obs):
        """Return ``raw_obs["observation"]`` for Dict observation spaces, else ``raw_obs``."""
        if isinstance(self.envs.observation_space, gym.spaces.Dict):
            return raw_obs["observation"]
        return raw_obs

    def reset(self,render=False):
        """Reset the environment and re-initialise the rollout storage.

        Args:
            render: When true, render the environment after the reset.

        Returns:
            The raw value returned by ``envs.reset()``.
        """
        obsf = self.envs.reset()
        obs = self._extract_observation(obsf)
        # The extra "state" entry is only read when args.state is enabled.
        state = torch.tensor(obsf["state"]) if self.args.state else None
        self.lastobs = self.rollouts.init(obs, state=state)

        self.context.features = None

        if render:
            self.envs.render()
        return obsf

    def step(self,goal,*args,render=False,act_state=None,**kwargs):
        """Choose an action for ``goal``, step the environment, record the transition.

        Args:
            goal: Goal forwarded to ``algo.act`` and stored with the transition.
            *args: Ignored.
            render: When true, render the environment after stepping.
            act_state: Forwarded unchanged to ``rollouts.insert``.
            **kwargs: Ignored.

        Returns:
            ``(lastobs, action, reward, done, infos)`` where ``lastobs`` is the
            raw environment observation and ``infos`` is the environment's
            info object wrapped in a one-element list.
        """
        self.total_num_steps += 1
        with torch.no_grad():
            # Actions are flagged as warm-up while the step counter is below args.warmup.
            self.running_act = self.algo.act(
                self.lastobs,
                goal,
                features=self.context.features,
                random_warmup=(self.total_num_steps < self.args.warmup),
            )

        lastobs, reward, done, infos = self.envs.step(self.running_act)
        if render:
            self.envs.render()

        self.lastobs = self._extract_observation(lastobs)

        # Cleared first so that a failure in label_embed does not leave the
        # previous step's features in place.
        self.context.features = None
        with torch.no_grad():
            self.context.features = self.context.estimator.label_embed(
                preprocess(self.lastobs, self.args), act=True
            )
        infos = [infos]
        # Mask is 0 where the info dict reports "over", 1 otherwise.
        self.lastmasks = torch.tensor(
            [not info["over"] for info in infos], dtype=torch.float
        )

        state = torch.tensor(lastobs["state"]) if self.args.state else None
        if not self.eval_mode:
            self.rollouts.insert(self.lastobs, infos, self.running_act, self.lastmasks, reward, goal=goal,
                                 features=self.context.features, state=state, act_state=act_state)

        # Episode returns are only accumulated in evaluation mode.
        if self.eval_mode:
            self.episode_rewards[-1] += reward[0].item()
            if done:
                self.episode_rewards.append(0.)
        self.context.logger.store(reward=reward.mean())

        return lastobs, self.running_act, reward, done, infos

    def change_goal(self,*args):
        """Notify the rollout storage that the goal changed; arguments are ignored."""
        self.rollouts.change_goal()

    def can_learn(self):
        """Return whatever ``rollouts.can_learn()`` reports."""
        return self.rollouts.can_learn()

    def learn(self):
        """Run one update of the algorithm when allowed.

        Nothing happens (and ``None`` is returned) in evaluation mode or when
        the rollout storage cannot learn yet. Otherwise the update counter is
        incremented and, once ``total_num_steps`` exceeds ``args.warmup``,
        returns are computed and the algorithm is evaluated on the rollouts.

        Returns:
            ``self.algo`` after an update attempt, ``None`` when skipped.
        """
        if self.eval_mode or not self.can_learn():
            return
        self.num_updates += 1

        batch = self.rollouts.get_evals()
        value_loss, values = 0, 0
        next_values, ret = 0, 0
        if self.args.warmup < self.total_num_steps:
            next_values = self.algo.get_value(batch)
            ret = self.rollouts.compute_returns(next_values, self.args.gamma)
            value_loss, values, _ = self.algo.evaluate(self.rollouts, ret.detach())

        # Fetched again rather than reusing `batch`: the calls above may have
        # modified the rollouts.
        self.logger.store(irewards0=self.rollouts.get_evals().irewards.mean())
        self.logger.store(value_loss=value_loss)
        self.logger.store(values=values)
        self.next_values, self.values, self.ret = next_values, values, ret
        return self.algo

    def eval(self):
        """Switch to evaluation mode (no storage of transitions, no learning)."""
        self.eval_mode = True

    def train(self):
        """Switch back to training mode."""
        self.eval_mode = False

    def after_update(self):
        """Forward the post-update hook to the rollout storage."""
        self.rollouts.after_update()

    def _log_oegn_stats(self, logger):
        """Log how much ``rollouts.oegn.embeds`` moved since the last report, plus its counters."""
        oegn = self.rollouts.oegn
        embeds = oegn.embeds
        if self.old_embeds is None:
            self.old_embeds = torch.zeros_like(embeds)
        # Compare only the rows present in both snapshots; using the common
        # length keeps the boolean mask aligned even if `embeds` shrank.
        common = min(self.old_embeds.shape[0], embeds.shape[0])
        previous = self.old_embeds[:common]
        current = embeds[:common]
        mask = oegn.available_nodes[:common]
        if not mask.any():
            logger.log_tabular("mean_change", 0)
            logger.log_tabular("max_change",0)
        else:
            change_embeds = torch.norm(previous[mask] - current[mask], 2, dim=1)
            logger.log_tabular("mean_change", change_embeds.mean().item())
            logger.log_tabular("max_change", change_embeds.max().item())
        self.old_embeds = embeds.clone()

        logger.log_tabular("inter-distances", oegn.compute_mean_interdistance())
        logger.log_tabular("num_buf", oegn.num_nodes)
        logger.log_tabular("del_buf", oegn.del_nodes)
        logger.log_tabular("del_closed", oegn.cpt_close_deleted)
        logger.log_tabular("del_time", oegn.cpt_time_deleted)
        logger.log_tabular("del_edges", oegn.del_edges)

    def _log_estimator_stats(self, logger):
        """Log estimator drift since the last report and its other diagnostics."""
        estimator = self.context.estimator
        logger.log_tabular("disc_updates", estimator.disc_updates)

        # Take the reference snapshot only once; afterwards the snapshot saved
        # at the end of the previous report is the one to compare against.
        if getattr(self, "old_obs_e", None) is None:
            self.old_obs_e = self.rollouts.get_evals().next_obs.clone()
            self.old_embeds_e = estimator.store.clone()
        actual_embeds = estimator.embed(self.old_obs_e)
        distances = torch.norm(actual_embeds - self.old_embeds_e, 2, dim=1)
        logger.log_tabular("mean_change_e", distances.mean().item())
        logger.log_tabular("max_change_e", distances.max().item())

        norm_successive = torch.norm(estimator.store - estimator.prev_store, 2, dim=1)
        logger.log_tabular("consecutive_e", norm_successive.mean().item())
        logger.log_tabular("max_consecutive_e", norm_successive.max().item())
        logger.log_tabular("min_consecutive_e", norm_successive.min().item())
        logger.log_tabular("negative_loss", estimator.save_sum_new_loss.item())

        if self.args.clone_negative > 0:
            logger.log_tabular("target_change", estimator.mean_ch_target.item())

        if estimator.interval_gradient > 0:
            logger.log_tabular("mean_gradient", estimator.log_gradients.mean().item())
            logger.log_tabular("var_gradient", estimator.log_gradients.std().item())

        # Reference snapshot for the next report.
        self.old_obs_e = self.rollouts.get_evals().next_obs.clone()
        self.old_embeds_e = estimator.store.clone()

    def print(self,**kwargs):
        """Dump a row of training diagnostics and reset the per-report state.

        The row is only written in training mode and when the rollout storage
        can learn. The frequency timer, ``intrinsic_reward`` and
        ``episode_rewards`` are reset on every call, whether or not a row was
        written. Keyword arguments are ignored.
        """
        if self.rollouts.can_learn() and not self.eval_mode:
            logger = self.logger
            now = time.time()
            # Guard against a zero interval on coarse clocks.
            elapsed = max(now - self.freq_timer, 1e-9)
            logger.log_tabular("reward",average_only=True)
            logger.log_tabular("nupdates", self.num_updates)
            logger.log_tabular("timesteps", self.total_num_steps)
            logger.log_tabular("ups", self.log_interval / elapsed)
            logger.log_tabular("total_time", float(now - self.start))

            logger.log_tabular("irewards0", average_only=True)
            logger.log_tabular("value_loss",average_only=True)
            logger.log_tabular("values",average_only=True)
            with torch.no_grad():
                self._log_oegn_stats(logger)
                self._log_estimator_stats(logger)

            logger.dump_tabular()

        self.freq_timer = time.time()
        self.intrinsic_reward = 0.
        self.episode_rewards = [0.]

    def save(self):
        """Save the algorithm, then the rollout storage."""
        self.algo.save()
        self.rollouts.save()

    def load(self):
        """Load the rollout storage, then the algorithm."""
        self.rollouts.load()
        self.algo.load()

    def clone(self, envs,**kwargs):
        """Build a new ``SubAgent`` on ``envs`` sharing args, algorithm and context.

        The rollout storage is cloned via ``rollouts.clone()``; extra keyword
        arguments are ignored.
        """
        rollouts = self.rollouts.clone()
        return SubAgent(self.args, envs, self.algo, self.context, rollouts, log_interval=self.log_interval)

